clear variables

% Root name of the per-specification parameter folders; the run suffix
% (e.g. '_27') is appended to it inside the main loop.
% NOTE(review): this variable shadows MATLAB's built-in path() function for
% the rest of the script.
path = 'parameters';
% Suffix of the parameter files param_ini_<name>.mat / param_end_<name>.mat.
name = 'replication';
% Output folder of the Excel workbook. It keeps a trailing separator because
% it may be concatenated directly with the file name; filesep (instead of a
% hard-coded '\') keeps the path valid on macOS/Linux as well as Windows.
datapath = ['Output', filesep];
% Workbook that all result sheets are written to.
filename = 'Chains.xlsx';
    
% Parameter-folder suffix of each specification; iteration k of the loop
% below solves specification run_specs{k}.
run_specs = {'_27', '_10', '_38', '_17', '_22'};
% Full path of the Excel workbook all results are written to (fullfile works
% whether or not datapath already ends in a separator).
outfile   = fullfile(datapath, filename);
% Cap on the number of theta-updating iterations per specification.
% Inf iterates until the excess-demand tolerance is met.
max_iter  = Inf;

% For each specification: rebuild the numerical grids, compute both steady
% states on them, solve for the transition path of theta by damped
% fixed-point iteration, write the results to the workbook, and save the
% converged path as the starting guess for the next run.
for shock_negative = 1:numel(run_specs)

    run_spec  = run_specs{shock_negative};
    parampath = [path, run_spec];

    % Previously saved transition; its theta path seeds the initial guess.
    load(fullfile(parampath, 'aux_transition.mat'));
    aux_old = aux;
    clear aux

    %% Numerical settings
    if shock_negative == 1 || shock_negative == 5
        aux.n_m = 1600;     % number of HJB grid points
    else
        aux.n_m = 800;
    end
    aux.N_T     = length(aux_old.theta);   % number of transition time nodes
    aux.pen     = 5*10^(5);
    % Convergence tolerances
    aux.tol_HJB = 10^(-7);
    aux.tol_Q   = 10^(-7);
    aux.tol_G   = 10^(-9);
    aux.CN      = 0.5;
    aux.T       = 12*7;                    % transition horizon (months)

    % Initial (param_ini) and terminal (param_end) steady-state parameters.
    load(fullfile(parampath, ['param_ini_', name, '.mat']));
    load(fullfile(parampath, ['param_end_', name, '.mat']));
    param = param_ini;

    % Transition time grid. It has exponentially spaced nodes on (0, T/2],
    % dense right after the shock, followed by uniformly spaced nodes on
    % (T/2, T]. The number of uniform nodes (~22% of N_T) is rounded to an
    % integer. linspace floors non-integer point counts, which would
    % otherwise make length(aux.t_vec) differ from aux.N_T.
    n_lead    = 1;
    scale     = 3.5;
    n_unif    = round(22*10^(-2).*aux.N_T);
    t_exp     = (exp(scale.*linspace(0, 1, n_lead + (aux.N_T - n_unif))) - 1)./(exp(scale) - 1).*aux.T/2;
    t_unif    = linspace(aux.T/2, aux.T, n_unif + 1);
    aux.t_vec = [t_exp(1+n_lead:end), t_unif(2:end)];
    aux.t_lam = aux.t_vec;

    %% log(m) grid for the HJB equation
    % Eight grid_ini_fun segments. The grid is refined within +-0.01 of
    % a = log(m_l) and c = log(m_e) of the terminal steady state. Segment
    % sizes add up to aux.n_m points once the shared endpoints are dropped.
    a = log(param_end.m_l);
    c = log(param_end.m_e);
    d = max(log(param_ini.m_e), log(param_end.m_e));

    % Spacing parameters passed to grid_ini_fun (coarse / fine segments).
    g_adj  = 1.5;
    g_adj1 = 0.5;
    if shock_negative == 3
        g_adj  = 0.1;
        g_adj1 = 0.1;
    end

    n_tail = aux.n_m/40;
    n_fine = aux.n_m/8 + 1;
    n_mid  = (aux.n_m - 2.*(aux.n_m/4 + aux.n_m/40))/2 + 1;
    m_mid  = ((c - 0.01) + (a + 0.01))/2;

    tmp1 = grid_ini_fun([a - 0.20, a - 0.01], g_adj,  n_tail,     -1)';
    tmp2 = grid_ini_fun([a - 0.01, a],        g_adj1, n_fine,     -1)';
    tmp3 = grid_ini_fun([a, a + 0.01],        g_adj1, n_fine,      1)';
    tmp4 = grid_ini_fun([a + 0.01, m_mid],    g_adj,  n_mid,       1)';
    tmp5 = grid_ini_fun([m_mid, c - 0.01],    g_adj,  n_mid,      -1)';
    tmp6 = grid_ini_fun([c - 0.01, c],        g_adj1, n_fine,     -1)';
    tmp7 = grid_ini_fun([c, c + 0.01],        g_adj1, n_fine,      1)';
    tmp8 = grid_ini_fun([c + 0.01, d + 0.20], g_adj,  n_tail + 1,  1)';
    % Drop the first point of every later segment: it repeats the previous end.
    param.ln_m_HJB = [tmp1; tmp2(2:end); tmp3(2:end); tmp4(2:end); ...
                      tmp5(2:end); tmp6(2:end); tmp7(2:end); tmp8(2:end)];
    m_grid_HJB = exp(param.ln_m_HJB);

    %% Steady states: analytical solutions, then solutions on the grid
    m_HJB_clamped = max(param.m_l, min(m_grid_HJB, param.m_u));
    J_end = J_fun(m_HJB_clamped, param_end);
    J_end_analy = J_end;
    J_ini = J_fun(m_HJB_clamped, param_ini);
    J_ini_analy = J_ini;
    Q_ini = q_fun(min(param.m_u, max(param.m_l, exp(ln_m_KFE([param.m_l, param.m_u], 1, aux)))), param);
    Q_ini = Q_ini./(Q_ini(1).*param.s + Q_ini(end) - Q_ini(1));
    Q_ini_analy = Q_ini;
    Q_end = q_fun(min(param_end.m_u, max(param_end.m_l, exp(ln_m_KFE([param_end.m_l, param_end.m_u], 1, aux)))), param_end);
    % NOTE(review): normalises the terminal distribution with param.s, i.e.
    % param_ini.s; this assumes s is the same in both steady states. Confirm.
    Q_end = Q_end./(Q_end(1).*param.s + Q_end(end) - Q_end(1));
    Q_end_analy = Q_end;

    aux.dt = 0.05;
    aux.n  = 2000;
    % KFE grids of both steady states (aux is not modified again before the
    % transition solve, so these can be reused below).
    ln_m_ini = ln_m_KFE([param_ini.m_l, param_ini.m_u], 1, aux);
    ln_m_end = ln_m_KFE([param_end.m_l, param_end.m_u], 1, aux);
    Q_ini = KFE_t([param_ini.m_l, param_ini.m_u], Q_ini, delta_fun(exp(ln_m_ini), param_ini), param_ini, aux);
    Q_end = KFE_t([param_end.m_l, param_end.m_u], Q_end, delta_fun(exp(ln_m_end), param_end), param_end, aux);

    % HJB operator on the log(m) grid. It is built from the initial-state
    % r, mu and sigma and shared by both steady states (see param_*.T below).
    Id = speye(aux.n_m);
    [T_p, ~, ~] = M_p(param.ln_m_HJB, ones(aux.n_m, 1));
    T_pp = M_pp(param.ln_m_HJB);
    param.T = param.r.*Id - (param.mu - param.sigma.^2./2).*T_p - param.sigma.^2./2.*T_pp;

    % Distributions interpolated onto the HJB grid, held constant outside
    % [m_l, m_u] of the respective steady state.
    Q_HJB_ini = interp1(ln_m_ini, Q_ini, param.ln_m_HJB, 'pchip', 'extrap');
    Q_HJB_ini(param.ln_m_HJB < log(param_ini.m_l)) = Q_ini(1);
    Q_HJB_ini(param.ln_m_HJB > log(param_ini.m_u)) = Q_ini(end);

    Q_HJB_end = interp1(ln_m_end, Q_end, param.ln_m_HJB, 'pchip', 'extrap');
    Q_HJB_end(param.ln_m_HJB < log(param_end.m_l)) = Q_end(1);
    Q_HJB_end(param.ln_m_HJB > log(param_end.m_u)) = Q_end(end);

    param_ini.ln_m_HJB = param.ln_m_HJB;
    param_end.ln_m_HJB = param.ln_m_HJB;
    param_ini.T = param.T;
    param_end.T = param.T;

    % aux.n time steps of HJB_t for both steady states.
    for k = 1:aux.n
        J_ini = HJB_t(J_ini, Q_HJB_ini, param_ini, aux);
        J_end = HJB_t(J_end, Q_HJB_end, param_end, aux);
    end

    % NOTE(review): both value functions are clamped with param.c + param.K
    % from the initial state; confirm c and K do not change with the shock.
    J_end = min(param.c + param.K, max(0, J_end));
    J_ini = min(param.c + param.K, max(0, J_ini));

    q = Q_end(2:end) - Q_end(1:end-1);
    param.mean = q'*exp(ln_m_end(2:end)./(1 - param.alpha));
    param.mean

    %% Check difference between analytical solutions and solutions on grid
    figure(1)
    plot(m_grid_HJB, J_ini, 'b')
    hold on
    plot(m_grid_HJB, J_ini_analy, 'k--')
    plot(m_grid_HJB, J_end, 'r')
    plot(m_grid_HJB, J_end_analy, 'g--')
    hold off

    figure(2)
    plot(Q_ini, 'b')
    hold on
    plot(Q_ini_analy, 'k--')
    plot(Q_end, 'r')
    plot(Q_end_analy, 'g--')
    hold off

    % table() takes its column headers from these variable names, so they
    % must not be renamed.
    m_grid_FPE = exp(ln_m_ini);
    tbl = table(m_grid_HJB, J_ini_analy, J_ini, m_grid_FPE, Q_ini_analy, Q_ini);
    writetable(tbl, outfile, 'Sheet', ['nume_acc', run_spec], 'Range', 'A1')

    %% Transition path
    % Assign boundaries at initial and end steady state
    param.m_l_ini = param_ini.m_l;
    param.m_u_ini = param_ini.m_u;

    param.m_l_end = param_end.m_l;
    param.m_h_end = param_end.m_h;
    param.m_e_end = param_end.m_e;
    param.m_u_end = param_end.m_u;

    % Assign new value of omega_0
    param.omega_0 = param_end.omega_0;

    Q_start = Q_ini;   % distribution at the initial steady state
    J_final = J_end;   % value function at the new steady state

    param_start = param_ini;
    param_final = param_end;

    % Initial guess for the theta path: the previously saved solution,
    % interpolated onto the new time grid and held flat outside its range.
    theta = interp1(aux_old.t_vec, aux_old.theta, aux.t_vec, 'linear');
    theta(aux.t_vec > aux_old.t_vec(end)) = aux_old.theta(end);
    theta(aux.t_vec < aux_old.t_vec(1))   = aux_old.theta(1);

    param.dist_m_grid = linspace(min(param_ini.m_l, param_end.m_l) - 0.01, max(param_ini.m_u, param_end.m_u) + 0.01);

    % Damped (sluggish) updating rule for theta. The step size decays as
    % exp(-adj2*t) along the transition.
    adj  = 2;
    adj2 = 0.05;
    if shock_negative == 3
        adj  = 1;
        adj2 = 0.08;
    end
    adj1  = exp(-adj2.*aux.t_lam);
    diff0 = 1;
    tol   = 10^(-5);

    A    = nan(1);
    iter = 0;
    while diff0 > tol && iter < max_iter
        iter   = iter + 1;
        diff_t = transition_t_new(theta, J_final, Q_start, A, param, aux, 1);
        A      = [diff_t.m_l_t; diff_t.m_h_t; diff_t.m_e_t];

        excess = diff_t.excess;
        diff0  = max(abs(excess));
        if diff0 > tol
            % Step limited to +-0.4 per node; theta kept strictly positive.
            theta = theta + adj.*max(-0.4, min(0.4, excess)).*adj1;
            theta = max(theta, 0.0001);
        end
        % Progress display (percent).
        100*[excess(1:5), excess(end), diff0, 100.*sum(abs(excess))/length(aux.t_lam)]
    end
    if diff0 > tol
        warning('Transition %s stopped after %d iterations without convergence (max |excess| = %g).', ...
            run_spec, iter, diff0);
    end

    %% Transition results on a fine uniform grid over the first T_fig months
    aux.T_fig = 12*5;

    res = struct();
    res.t_grid = linspace(0, 1, 1001).*aux.T_fig;
    res.t_grid = res.t_grid(2:end);
    res.u = interp1(aux.t_vec, diff_t.Q_t(:,1).*param.s, res.t_grid, 'linear', 'extrap');
    path_fields = {'lambda', 'ete', 'mean_m', 'hires_tot', 'G_r', 'new_chains', ...
                   'm_l_t', 'm_h_t', 'm_e_t', 'm_u_t'};
    for f = 1:numel(path_fields)
        fld = path_fields{f};
        res.(fld) = interp1(aux.t_vec, diff_t.(fld), res.t_grid, 'linear', 'extrap');
    end

    %% Flows (two pre-shock rows at t = -12 and t = -0.00001, then the path)
    m_KFE_start = exp(ln_m_KFE([param_start.m_l, param_start.m_u], 1, aux));
    u_ini  = Q_start(1).*param.s;
    u      = [u_ini, res.u];
    lambda = [param_start.lambda, res.lambda];
    ete_ini   = jtj_num(delta_fun(m_KFE_start, param_start), Q_start);
    hires_ini = ete_ini.*(1 - u(1)) + param_start.lambda.*u(1);
    ete   = [ete_ini; ete_ini; res.ete'];
    hires = [hires_ini; hires_ini; res.hires_tot'];

    m_l_t = [param_start.m_l; param_start.m_l; res.m_l_t'];
    m_h_t = [param_start.m_h; param_start.m_h; res.m_h_t'];
    m_e_t = [param_start.m_e; param_start.m_e; res.m_e_t'];
    m_u_t = [param_start.m_u; param_start.m_u; res.m_u_t'];

    G_r_ini = interp1(m_KFE_start, (Q_start - Q_start(1))./(Q_start(end) - Q_start(1)), param_start.m_h, 'linear');
    new_chains_ini = param.s.*param_start.lambda.*G_r_ini.*(1 - u(1)) + u(1).*param_start.lambda;
    new_chains = [new_chains_ini; new_chains_ini; res.new_chains'];
    G_r = [G_r_ini; G_r_ini; res.G_r'];

    % EU rate backed out of du/dt = EU*(1-u) - UE*u, i.e.
    % EU = (du/dt + UE*u)/(1-u), with UE = lambda.
    dt_fig = [res.t_grid(1), res.t_grid(2:end) - res.t_grid(1:end-1)];
    eu = (u(2:end) - u(1:end-1))./dt_fig./(1 - u(2:end)) + res.lambda.*u(2:end)./(1 - u(2:end));
    t_grid = [-12; -0.00001; res.t_grid'];
    u = [u(1), u]';
    eu_ini = lambda(1).*u_ini(1)./(1 - u_ini(1));
    eu = [eu_ini; eu_ini; eu'];
    mean_m = mean_t_fun(m_KFE_start, Q_start);

    %% Time aggregated (monthly) flows
    n_agg      = 200;                      % sub-steps per month
    t_grid_agg = 0:aux.T_fig;              % month boundaries
    t_grid_tmp = linspace(0, 1, aux.T_fig*n_agg + 1).*aux.T_fig;
    lambda_tmp     = interp1(aux.t_vec, diff_t.lambda,      t_grid_tmp, 'linear', 'extrap');
    ete_tmp        = interp1(aux.t_vec, diff_t.ete,         t_grid_tmp, 'linear', 'extrap');
    hires_tmp      = interp1(aux.t_vec, diff_t.hires_tot,   t_grid_tmp, 'linear', 'extrap');
    new_chains_tmp = interp1(aux.t_vec, diff_t.new_chains', t_grid_tmp, 'linear', 'extrap');

    u_tmp    = interp1(aux.t_vec, diff_t.Q_t(:,1).*param.s, t_grid_tmp, 'linear', 'extrap');
    u_tmp(1) = Q_start(1).*param.s;        % measure u(0) prior to shock to get eu right
    u_agg    = interp1(aux.t_vec, diff_t.Q_t(:,1).*param.s, t_grid_agg, 'linear', 'extrap');
    u_agg(1) = Q_start(1).*param.s;        % measure u(0) prior to shock to get eu right

    % Sum of n_agg consecutive sub-step values -> one value per month.
    month_sum = @(x) sum(reshape(x, n_agg, aux.T_fig), 1);
    u_lo = u_tmp(1:end-1);                 % u at the left end of each sub-step

    new_chains_agg = month_sum(new_chains_tmp(1:end-1)/n_agg);
    hires_agg      = month_sum(hires_tmp(1:end-1)/n_agg);
    % Rates are flow-weighted within the month and normalised by the stock at
    % the start of the month.
    ete_agg    = month_sum(ete_tmp(1:end-1).*(1 - u_lo)./n_agg)./(1 - u_agg(1:end-1));
    lambda_agg = month_sum(lambda_tmp(1:end-1)/n_agg.*u_lo)./u_agg(1:end-1);
    eu_tmp = (u_tmp(2:end) - u_lo)./(t_grid_tmp(2:end) - t_grid_tmp(1:end-1))./(1 - u_lo) ...
             + lambda_tmp(1:end-1).*u_lo./(1 - u_lo);
    eu_agg = month_sum(eu_tmp/n_agg.*(1 - u_lo))./(1 - u_agg(1:end-1));

    %% Write results
    if shock_negative == 5
        sheet_name = 'Figure 8';
    else
        sheet_name = ['transition', run_spec];
    end

    varNames = {'Time_months', ...
        'Unemployment', 'UE', 'EU', 'ete', 'total_hires', 'new_chains'};
    tbl = table([-12; t_grid_agg'], ...
        [u_ini, u_ini, u_agg(2:end)]', [param_start.lambda, param_start.lambda, lambda_agg]', ...
        [eu_ini, eu_ini, eu_agg]', [ete_ini, ete_ini, ete_agg]', [hires_ini, hires_ini, hires_agg]', ...
        [new_chains_ini, new_chains_ini, new_chains_agg]', ...
        'VariableNames', varNames);
    writetable(tbl, outfile, 'Sheet', sheet_name, 'Range', 'Q1')

    varNames = {'Time_months', 'Productivity', ...
        'Output', 'Average_m', ...
        'Unemployment', 'UE', 'EU', 'ete', 'total_hires', 'new_chains', 'G_m_h', 'm_l', 'm_h', 'm_e', 'm_u'};
    tbl = table(t_grid, [1; 1; (1 - param.shock).*ones(size(res.t_grid'))], ...
        [mean_m; mean_m; res.mean_m']./param.alpha.*(1 - u), [mean_m; mean_m; res.mean_m'], ...
        u, [lambda(1); lambda'], eu, ete, hires, new_chains, G_r, m_l_t, m_h_t, m_e_t, m_u_t, ...
        'VariableNames', varNames);
    writetable(tbl, outfile, 'Sheet', sheet_name, 'Range', 'A1')

    % Save the converged path; it is the starting guess of the next run.
    aux.diff  = res;
    aux.theta = theta;
    save(fullfile(parampath, 'aux_transition.mat'), 'aux');

    if shock_negative == 5
        % Distributions every quarter month, on param.dist_m_grid.
        t_dist = 0.25:0.25:aux.T_fig;
        Q_dist = interp1(aux.t_vec, diff_t.q_dist, t_dist);
        Q_dist = Q_dist./Q_dist(:,end);    % normalise each row by its last entry
        delta_dist = interp1(aux.t_vec, diff_t.delta_dist, t_dist);

        tbl = table([[nan(1), t_dist]; ...
            param.dist_m_grid', Q_dist']);
        writetable(tbl, outfile, 'Sheet', 'transition_q', 'Range', 'A2')

        tbl = table([[nan(1), t_dist]; ...
            param.dist_m_grid', delta_dist']);
        writetable(tbl, outfile, 'Sheet', 'transition_delta', 'Range', 'A2')
    end

end
